import numpy as np
import torch
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import TensorBoardOutputFormat
import matplotlib.pyplot as plt

class SummaryWriterCallback(BaseCallback):
    """Write per-step and per-episode diagnostics directly to TensorBoard.

    The callback bypasses the regular SB3 logger aggregation and calls
    ``add_scalar`` on the underlying TensorBoard writer itself, so values
    are recorded against ``self.n_calls`` rather than averaged.

    Each entry of ``self.locals['infos']`` is read for the keys
    ``'pos_difference'``, ``'episode_reward'`` and ``'cumreward'``; the
    first entry is additionally read for ``'action'``.  These keys must be
    supplied by the environment.
    """

    # Completed-episode history is scanned every this many callback calls.
    _EPISODE_LOG_FREQ = 200

    def _on_training_start(self) -> None:
        """Reset the bookkeeping and locate the TensorBoard output format.

        Raises:
            RuntimeError: if the model's logger has no TensorBoard output
                (e.g. ``tensorboard_log`` was not configured).
        """
        # Cumulative rewards are written every `_log_freq` callback calls.
        self._log_freq = 10
        # Per-environment cursors, sized lazily on the first step:
        #   last_episode[i] - number of 'episode_reward' entries already logged
        #   last_ep_len[i]  - x-axis position of the last logged episode
        self.last_episode = []
        self.last_ep_len = []

        # Keep a reference to the TensorBoard formatter so that its writer
        # can be used directly in `_on_step`.
        self.tb_formatter = next(
            (
                formatter
                for formatter in self.logger.output_formats
                if isinstance(formatter, TensorBoardOutputFormat)
            ),
            None,
        )
        if self.tb_formatter is None:
            # A bare `next()` would surface here as an unexplained
            # StopIteration; fail with an actionable message instead.
            raise RuntimeError(
                "SummaryWriterCallback requires a TensorBoard output format, "
                "but none is attached to the logger. Configure TensorBoard "
                "logging (e.g. pass `tensorboard_log=...` to the model)."
            )

    def _on_step(self) -> bool:
        """Log the scalars for the current step.

        Returns:
            Always ``True`` so that training continues.
        """
        writer = self.tb_formatter.writer
        infos = self.locals['infos']
        num_envs = self.locals['env'].num_envs
        step = self.n_calls

        # Only the first environment's action is tracked.
        action = infos[0]['action']
        writer.add_scalar("actions/delX env #1", action[0], step)
        writer.add_scalar("actions/delY env #1", action[1], step)

        # NOTE: experimental logging of terminal robot poses (scatter image)
        # and of desired-vs-achieved positions (add_mesh) used to live here
        # as commented-out code; it was never enabled.

        if not self.last_episode:
            # First step: the number of environments is only known now.
            self.last_episode = [0] * num_envs
            self.last_ep_len = [0] * num_envs

        scan_episodes = bool(step) and step % self._EPISODE_LOG_FREQ == 0
        log_cumulative = step % self._log_freq == 0

        for env_idx in range(num_envs):
            info = infos[env_idx]
            self._log_step_scalars(writer, env_idx, info, step)
            if scan_episodes:
                self._log_new_episodes(writer, env_idx, info)
            if log_cumulative:
                writer.add_scalar("cum rewards/env #{}".format(env_idx + 1),
                                  info['cumreward'],
                                  step)

        # One flush per step is enough to keep the event file current;
        # flushing after every single scalar only adds disk I/O.
        writer.flush()
        return True

    def _log_step_scalars(self, writer, env_idx: int, info, step: int) -> None:
        """Write the step reward and position error of one environment."""
        writer.add_scalar("step rewards/env #{}".format(env_idx + 1),
                          self.locals['rewards'][env_idx],
                          step)
        # NOTE(review): these two tags are shared by all environments, so
        # with more than one environment several points land on the same
        # step of the same curve.  Tags are kept as-is to stay compatible
        # with existing logs - consider adding the env number.
        pos_difference = info['pos_difference']
        writer.add_scalar("expected-actual/X", pos_difference[0], step)
        writer.add_scalar("expected-actual/Y", pos_difference[1], step)

    def _log_new_episodes(self, writer, env_idx: int, info) -> None:
        """Write every not-yet-logged entry of ``info['episode_reward']``.

        Each entry is indexed as ``entry[0]`` and ``entry[1]``:
        ``entry[1]`` is the logged value and ``entry[0]`` is added to a
        running offset that serves as the x-axis (presumably the episode
        length and episode reward respectively - confirm against the
        environment).
        """
        history = info['episode_reward']
        already_logged = self.last_episode[env_idx]
        total = len(history)
        if total <= already_logged:
            return

        x_pos = self.last_ep_len[env_idx]
        tag = "episodic/env #{}".format(env_idx + 1)
        # Index rather than slice so any indexable sequence keeps working.
        for ep_idx in range(already_logged, total):
            entry = history[ep_idx]
            ep_reward = entry[1]
            x_pos = entry[0] + x_pos
            writer.add_scalar(tag, ep_reward, x_pos)

        self.last_episode[env_idx] = total
        self.last_ep_len[env_idx] = x_pos